{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training GNNs on Large Graphs\n",
    "\n",
    "We have seen how to train GNNs on an entire graph. In practice, however, graphs are often very large, with millions or billions of nodes and edges, and node and edge features can make the required storage many times larger than the graph structure alone. If we want to use GPUs for faster computation, we will find that full-graph training is often impossible, because the graph and its features do not fit into the memory of a single GPU. On top of that, the node representations of every intermediate layer must also be stored for backpropagation.\n",
    "\n",
    "To overcome this limit, we combine two techniques:\n",
    "\n",
    "1. Stochastic training on graphs.\n",
    "2. Neighbor sampling on graphs."
   ]
  },
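  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scale concrete, here is a back-of-the-envelope estimate in plain Python. It is a sketch under simplifying assumptions: all values are dense float32 (4 bytes each), every layer is assumed to output `hidden_dim` values per node, and the node count and dimensions are hypothetical numbers rather than measurements of any particular dataset.\n",
    "\n",
    "```python\n",
    "BYTES_PER_FLOAT32 = 4\n",
    "\n",
    "def dense_matrix_gb(num_rows, num_cols, bytes_per_value=BYTES_PER_FLOAT32):\n",
    "    # Size in gigabytes (10**9 bytes) of a dense num_rows x num_cols matrix.\n",
    "    return num_rows * num_cols * bytes_per_value / 1e9\n",
    "\n",
    "# Hypothetical graph: 100 million nodes with 128-dimensional input features,\n",
    "# processed by a 2-layer GNN whose layers output 256 values per node.\n",
    "num_nodes = 100_000_000\n",
    "feature_dim = 128\n",
    "hidden_dim = 256\n",
    "num_layers = 2\n",
    "\n",
    "input_gb = dense_matrix_gb(num_nodes, feature_dim)\n",
    "# Full-graph training keeps every layer's output for all nodes until backpropagation.\n",
    "activations_gb = num_layers * dense_matrix_gb(num_nodes, hidden_dim)\n",
    "print(f'input features: {input_gb:.1f} GB')\n",
    "print(f'intermediate representations: {activations_gb:.1f} GB')\n",
    "```\n",
    "\n",
    "Under these assumptions the estimate comes to roughly 51 GB of input features plus about 205 GB of intermediate representations, before counting the graph structure itself."
   ]
  },
```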
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GraphSAGE Recap\n",
    "\n",
    "In the previous session, we discussed the GraphSAGE model. The output representation $h_v^k$ of node $v$ at the $k$-th layer is computed as:\n",
    "\n",
    "$$h_{N(v)}^k \\leftarrow \\mathrm{AGGREGATE}_k\\left(\\{h_u^{k-1}, \\forall u \\in N(v)\\}\\right)$$\n",
    "\n",
    "$$h_v^k \\leftarrow \\sigma\\left(W^k \\cdot \\mathrm{CONCAT}(h_v^{k-1}, h_{N(v)}^k)\\right)$$\n",
    "\n",
    "Note: the input of a GraphSAGE layer includes both the neighbors' representations and the destination nodes' own representations from the previous layer."
   ]
  },
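  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two equations can be traced by hand on a tiny graph. The following plain-Python sketch (no DGL) assumes that AGGREGATE is the element-wise mean and σ is ReLU; the graph, the features, and the weight matrix `W` are made-up values for illustration only. Concatenating two Python lists with `+` plays the role of CONCAT.\n",
    "\n",
    "```python\n",
    "def relu(vector):\n",
    "    return [max(0.0, x) for x in vector]\n",
    "\n",
    "def mean_aggregate(vectors):\n",
    "    # Element-wise mean of a non-empty list of equal-length vectors.\n",
    "    return [sum(values) / len(vectors) for values in zip(*vectors)]\n",
    "\n",
    "def matvec(matrix, vector):\n",
    "    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]\n",
    "\n",
    "def sage_layer(adjacency, h_prev, W):\n",
    "    # adjacency: node -> list of neighbors N(v); h_prev: node -> h^{k-1}.\n",
    "    h_next = {}\n",
    "    for v, neighbors in adjacency.items():\n",
    "        h_neigh = mean_aggregate([h_prev[u] for u in neighbors])  # h_{N(v)}^k\n",
    "        h_next[v] = relu(matvec(W, h_prev[v] + h_neigh))          # sigma(W . CONCAT)\n",
    "    return h_next\n",
    "\n",
    "# Toy path graph 0 - 1 - 2 with 2-dimensional input features.\n",
    "adjacency = {0: [1], 1: [0, 2], 2: [1]}\n",
    "h0 = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}\n",
    "# W maps the 4-dimensional concatenation to a 2-dimensional output.\n",
    "W = [[1.0, 0.0, 1.0, 0.0],\n",
    "     [0.0, 1.0, 0.0, -1.0]]\n",
    "h1 = sage_layer(adjacency, h0, W)\n",
    "print(h1)  # {0: [1.0, 0.0], 1: [1.0, 0.5], 2: [1.0, 0.0]}\n",
    "```\n",
    "\n",
    "For example, node 1 averages its neighbors' features [1, 0] and [1, 1] into [1, 0.5], concatenates its own feature [0, 1] with this mean, and multiplies by `W` to get [1, 0.5]."
   ]
  },
```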
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import dgl.function as fn\n",
    "from dgl.nn.pytorch import conv as dgl_conv\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mini-batch Construction from a Graph\n",
    "\n",
    "For stochastic training, we split the training data into small mini-batches and, at each training step, move only the information needed for the current mini-batch to the GPU. For node classification, this means splitting the labeled nodes into mini-batches. Let's take a closer look at what information a mini-batch of nodes actually requires."
   ]
  },
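  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At its simplest, splitting the labeled nodes into mini-batches means shuffling their IDs and cutting them into chunks. A minimal plain-Python sketch; `make_minibatches` and the node IDs are hypothetical. Later in this notebook, PyTorch's `DataLoader` takes over this job.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def make_minibatches(labeled_node_ids, batch_size, seed=None):\n",
    "    # Shuffle a copy of the labeled node IDs, then cut it into consecutive chunks.\n",
    "    ids = list(labeled_node_ids)\n",
    "    random.Random(seed).shuffle(ids)\n",
    "    return [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]\n",
    "\n",
    "# Hypothetical labeled (training) nodes.\n",
    "train_nodes = [0, 3, 5, 8, 9, 10, 11]\n",
    "batches = make_minibatches(train_nodes, batch_size=3, seed=0)\n",
    "print(batches)  # three batches of sizes 3, 3 and 1; the order depends on the seed\n",
    "```\n",
    "\n",
    "Each mini-batch then only needs the data required to compute the outputs of its own nodes, which is what the rest of this section works out."
   ]
  },
```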
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's use NetworkX to construct the toy graph. *Note*: please don't use NetworkX when you want to scale to large graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A small graph\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "example_graph_nx = nx.Graph(\n",
    "    [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),\n",
    "     (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),\n",
    "     (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)])\n",
    "\n",
    "example_graph = dgl.graph(example_graph_nx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![The toy graph](assets/graph.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single layer\n",
    "\n",
    "If we wish to compute the output representations of nodes 4 and 6 with a GraphSAGE layer, we need the input features of nodes 4 and 6 themselves, as well as those of their neighbors (nodes 7, 0, and 2).\n",
    "\n",
    "To construct a mini-batch with neighbor sampling, we can use the DGL API `dgl.sampling.sample_neighbors`. It takes a graph, a set of seed nodes, and a fanout, and returns a graph that keeps, for each seed node, at most `fanout` randomly sampled incoming edges. The returned graph still contains all nodes of the original graph; only the edges are sampled. Such a graph describes exactly the computation dependency above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_node_batch = torch.LongTensor([4, 6])   # These are the nodes whose outputs are to be computed\n",
    "sampled_graph = dgl.sampling.sample_neighbors(example_graph, sampled_node_batch, 2)\n",
    "print('|V|={}, |E|={}'.format(sampled_graph.number_of_nodes(), sampled_graph.number_of_edges()))\n",
    "src, dst = sampled_graph.all_edges()\n",
    "for s, d in zip(src, dst):\n",
    "    print(s.numpy(), d.numpy())"
   ]
  },
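  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the following plain-Python sketch performs the same kind of uniform neighbor sampling without replacement on an adjacency dictionary. `sample_in_neighbors` is a hypothetical helper, not DGL's implementation; like `dgl.sampling.sample_neighbors` with its default settings, it samples incoming edges of the seed nodes and keeps all of them when a node has no more than `fanout` neighbors.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def sample_in_neighbors(in_neighbors, seeds, fanout, rng=random):\n",
    "    # For each seed, keep all incoming edges if there are at most `fanout` of them;\n",
    "    # otherwise choose `fanout` of them uniformly at random without replacement.\n",
    "    edges = []\n",
    "    for dst in seeds:\n",
    "        candidates = in_neighbors[dst]\n",
    "        if len(candidates) <= fanout:\n",
    "            chosen = candidates\n",
    "        else:\n",
    "            chosen = rng.sample(candidates, fanout)\n",
    "        edges.extend((src, dst) for src in chosen)\n",
    "    return edges\n",
    "\n",
    "# In-neighbors of nodes 4 and 6 in the toy graph above.\n",
    "in_neighbors = {4: [0, 2, 7], 6: [0, 2]}\n",
    "edges = sample_in_neighbors(in_neighbors, [4, 6], fanout=2, rng=random.Random(0))\n",
    "print(edges)  # two of node 4's three neighbors, and both neighbors of node 6\n",
    "```"
   ]
  },
```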
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
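    "DGL further provides a bipartite structure called a *block* that reflects this computation dependency more directly: its source nodes are the nodes whose input features are read, and its destination nodes are the nodes whose outputs are computed. Nodes in a block are relabeled with consecutive local IDs, and their IDs in the original graph are stored in `srcdata[dgl.NID]` and `dstdata[dgl.NID]`. A subgraph can be converted to a block with the function `dgl.to_block`."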
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_block = dgl.to_block(sampled_graph, sampled_node_batch)\n",
    "\n",
    "def print_block_info(sampled_block):\n",
    "    print('#source:', sampled_block.number_of_src_nodes())\n",
    "    sampled_input_nodes = sampled_block.srcdata[dgl.NID]\n",
    "    print('Node ID of source nodes in original graph:', sampled_input_nodes)\n",
    "\n",
    "    sampled_output_nodes = sampled_block.dstdata[dgl.NID]\n",
    "    print('#destination:', sampled_block.number_of_dst_nodes())\n",
    "    print('Node ID of destination nodes in original graph:', sampled_output_nodes)\n",
    "\n",
    "    sampled_block_edges_src, sampled_block_edges_dst = sampled_block.all_edges()\n",
    "    print('edges in local node Ids')\n",
    "    for s, d in zip(sampled_block_edges_src, sampled_block_edges_dst):\n",
    "        print(s.numpy(), d.numpy())\n",
    "    # We need to map the src and dst node IDs in the blocks to the node IDs in the original graph.\n",
    "    sampled_block_edges_src_mapped = sampled_input_nodes[sampled_block_edges_src]\n",
    "    sampled_block_edges_dst_mapped = sampled_output_nodes[sampled_block_edges_dst]\n",
    "    print('edges in the original node Ids')\n",
    "    for s, d in zip(sampled_block_edges_src_mapped, sampled_block_edges_dst_mapped):\n",
    "        print(s.numpy(), d.numpy())\n",
    "    \n",
    "print_block_info(sampled_block)"
   ]
  },
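  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relabeling performed by `dgl.to_block` can be mimicked in a few lines of plain Python. This is a hypothetical sketch (`to_block_sketch` is not a DGL function), and the order of the non-destination source nodes may differ from DGL's. It does share the key property of a DGL block relied on later in this notebook: the destination nodes are placed first among the source nodes. The input edges are a made-up sample in the style of the output above.\n",
    "\n",
    "```python\n",
    "def to_block_sketch(edges, seeds):\n",
    "    # edges: (src, dst) pairs in original node IDs; every dst is one of the seeds.\n",
    "    dst_nodes = list(seeds)\n",
    "    dst_local = {nid: i for i, nid in enumerate(dst_nodes)}\n",
    "    # Source nodes start with the destination nodes; other sources follow in order\n",
    "    # of first appearance.\n",
    "    src_nodes = list(seeds)\n",
    "    src_local = dict(dst_local)\n",
    "    for src, _ in edges:\n",
    "        if src not in src_local:\n",
    "            src_local[src] = len(src_nodes)\n",
    "            src_nodes.append(src)\n",
    "    # Rewrite every edge in local IDs.\n",
    "    local_edges = [(src_local[s], dst_local[d]) for s, d in edges]\n",
    "    return src_nodes, dst_nodes, local_edges\n",
    "\n",
    "edges = [(0, 4), (7, 4), (0, 6), (2, 6)]\n",
    "src_nodes, dst_nodes, local_edges = to_block_sketch(edges, [4, 6])\n",
    "print(src_nodes)    # [4, 6, 0, 7, 2]\n",
    "print(dst_nodes)    # [4, 6]\n",
    "print(local_edges)  # [(2, 0), (3, 0), (2, 1), (4, 1)]\n",
    "```\n",
    "\n",
    "Mapping the local IDs back through `src_nodes` and `dst_nodes` recovers the original edges, which is what `print_block_info` does with `srcdata[dgl.NID]` and `dstdata[dgl.NID]`."
   ]
  },
```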
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple Layers\n",
    "\n",
    "Now we wish to compute the outputs of nodes 4 and 6 from a 2-layer GraphSAGE model. This requires the input features of not only the nodes themselves and their neighbors, but also the neighbors of these neighbors.\n",
    "\n",
    "To compute the 2-layer outputs of nodes 4 and 6, we first need the 1-layer outputs of nodes 4 and 6 as well as of their neighbors. To obtain the 1-layer outputs of all these nodes, we in turn need their input features as well as those of *their* neighbors.\n",
    "\n",
    "We can see that generating the computation dependency for a multi-layer GNN is a bottom-up process: we start from the output layer and grow the node set towards the input layer.\n",
    "\n",
    "The following sampler generates this computation dependency and returns it as a list of blocks, one per layer, ordered from the input layer to the output layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeighborSampler(object):\n",
    "    def __init__(self, g, num_fanouts):\n",
    "        \"\"\"\n",
    "        num_fanouts : list of fanouts, one per layer, ordered from the first (input) layer to the last (output) layer.\n",
    "        \"\"\"\n",
    "        self.g = g\n",
    "        self.num_fanouts = num_fanouts\n",
    "        \n",
    "    def sample(self, seeds):\n",
    "        seeds = torch.LongTensor(seeds)\n",
    "        blocks = []\n",
    "        for fanout in reversed(self.num_fanouts):\n",
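    "            # Take all in-neighbors with in_subgraph if the fanout is at least the number of nodes;\n",
    "            # otherwise sample at most `fanout` in-neighbors per seed node with sample_neighbors.\n",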
    "            if fanout >= self.g.number_of_nodes():\n",
    "                sampled_graph = dgl.in_subgraph(self.g, seeds)\n",
    "            else:\n",
    "                sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n",
    "            \n",
    "            sampled_block = dgl.to_block(sampled_graph, seeds)\n",
    "            seeds = sampled_block.srcdata[dgl.NID]\n",
    "            # Because the computation dependency is generated bottom-up, we prepend the new block instead of\n",
    "            # appending it.\n",
    "            blocks.insert(0, sampled_block)\n",
    "            \n",
    "        return blocks"
   ]
  },
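  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same bottom-up procedure can be written without DGL. In this plain-Python sketch, `sample_blocks` is a hypothetical helper that represents each block as a tuple `(src_nodes, dst_nodes, edges)` in original node IDs. Like `NeighborSampler.sample`, it walks the fanouts from the output layer down and prepends each block, so the destination nodes of each block are exactly the source nodes of the block after it.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def sample_blocks(in_neighbors, seeds, fanouts, rng=random):\n",
    "    # fanouts are ordered from the first (input) layer to the last (output) layer.\n",
    "    blocks = []\n",
    "    for fanout in reversed(fanouts):\n",
    "        dst_nodes = list(seeds)\n",
    "        src_nodes = list(seeds)  # destination nodes come first, as in a DGL block\n",
    "        seen = set(src_nodes)\n",
    "        edges = []\n",
    "        for dst in dst_nodes:\n",
    "            candidates = in_neighbors.get(dst, [])\n",
    "            if len(candidates) <= fanout:\n",
    "                chosen = candidates\n",
    "            else:\n",
    "                chosen = rng.sample(candidates, fanout)\n",
    "            for src in chosen:\n",
    "                edges.append((src, dst))\n",
    "                if src not in seen:\n",
    "                    seen.add(src)\n",
    "                    src_nodes.append(src)\n",
    "        # Generated from the output layer downwards, so prepend instead of append.\n",
    "        blocks.insert(0, (src_nodes, dst_nodes, edges))\n",
    "        # The layer below must produce outputs for every source node of this block.\n",
    "        seeds = src_nodes\n",
    "    return blocks\n",
    "\n",
    "# The toy graph from above, stored as undirected in-neighbor lists.\n",
    "toy_edges = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),\n",
    "             (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),\n",
    "             (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]\n",
    "in_neighbors = {}\n",
    "for u, v in toy_edges:\n",
    "    in_neighbors.setdefault(u, []).append(v)\n",
    "    in_neighbors.setdefault(v, []).append(u)\n",
    "\n",
    "blocks = sample_blocks(in_neighbors, [4, 6], fanouts=[2, 2], rng=random.Random(0))\n",
    "for layer, (src_nodes, dst_nodes, edges) in enumerate(blocks):\n",
    "    print('layer', layer, 'inputs:', src_nodes, 'outputs:', dst_nodes)\n",
    "```"
   ]
  },
```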
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "block_sampler = NeighborSampler(example_graph, [2, 2])\n",
    "sampled_blocks = block_sampler.sample(sampled_node_batch)\n",
    "\n",
    "print('#blocks:', len(sampled_blocks))\n",
    "print('Block for first layer')\n",
    "print('---------------------')\n",
    "print_block_info(sampled_blocks[0])\n",
    "print()\n",
    "print('Block for second layer')\n",
    "print('----------------------')\n",
    "print_block_info(sampled_blocks[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mini-batch training for 2-layer GraphSAGE\n",
    "\n",
    "### GraphSAGE on blocks\n",
    "\n",
    "The sampled block is essentially a bipartite graph. We have seen in the previous example that DGL's built-in `SAGEConv` module works on a whole graph. Does it also work properly on a *block*? The answer is yes: `SAGEConv`, like many of DGL's built-in neural network modules, accepts both homogeneous graphs and bipartite graphs such as blocks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl.nn as dglnn\n",
    "\n",
    "class SAGENet(nn.Module):\n",
    "    def __init__(self, n_layers, in_feats, out_feats, hidden_feats=None):\n",
    "        super().__init__()\n",
    "        self.convs = nn.ModuleList()\n",
    "        \n",
    "        if hidden_feats is None:\n",
    "            hidden_feats = out_feats\n",
    "        \n",
    "        if n_layers == 1:\n",
    "            self.convs.append(dglnn.SAGEConv(in_feats, out_feats, 'mean'))\n",
    "        else:\n",
    "            self.convs.append(dglnn.SAGEConv(in_feats, hidden_feats, 'mean', activation=F.relu))\n",
    "            for i in range(n_layers - 2):\n",
    "                self.convs.append(dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean', activation=F.relu))\n",
    "            self.convs.append(dglnn.SAGEConv(hidden_feats, out_feats, 'mean'))\n",
    "        \n",
    "    def forward(self, blocks, input_features):\n",
    "        \"\"\"\n",
    "        blocks : List of blocks generated by block sampler.\n",
    "        input_features : Input features of the first block.\n",
    "        \"\"\"\n",
    "        h = input_features\n",
    "        for layer, block in zip(self.convs, blocks):\n",
    "            h = self.propagate(block, h, layer)\n",
    "        return h\n",
    "    \n",
    "    def propagate(self, block, src_feats, layer):\n",
    "        # Because GraphSAGE requires not only the features of the neighbors, but also the features\n",
    "        # of the output nodes themselves on the current layer, we need to copy the output node features\n",
    "        # from the input side to the output side ourselves to make GraphSAGE work correctly.\n",
    "        # The output nodes of a block are guaranteed to appear first among its input nodes, so we can\n",
    "        # simply take the first number_of_dst_nodes() rows of the input features:\n",
    "        dst_feats = src_feats[:block.number_of_dst_nodes()]\n",
    "        return layer(block, (src_feats, dst_feats))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference with mini-batch\n",
    "\n",
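    "Inference can also be computed in mini-batches. Unlike training, inference aggregates over all neighbors rather than a sample, and it proceeds layer by layer. For a multi-layer GraphSAGE model, we first compute the 1st-layer representations of all nodes, batch by batch, taking all neighbors into account. Once the 1st-layer representations of all nodes are available, we use them as input to compute the 2nd-layer representations of all nodes. We repeat the process until we reach the last layer. Computing one layer at a time for all nodes avoids recomputing the intermediate representations shared by overlapping multi-hop neighborhoods, which would happen if each batch expanded its full multi-layer dependency independently."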
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_with_sagenet(sagenet, graph, input_features, batch_size):\n",
    "    block_sampler = NeighborSampler(graph, [graph.number_of_nodes()])\n",
    "    h = input_features\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        # We are computing all representations of one layer at a time.\n",
    "        # The outer loop iterates over GNN layers.\n",
    "        for conv in sagenet.convs:\n",
    "            new_h_list = []\n",
    "            node_ids = torch.arange(graph.number_of_nodes())\n",
    "            # The inner loop iterates over batch of nodes.\n",
    "            for batch_start in range(0, graph.number_of_nodes(), batch_size):\n",
    "                # Sample a block with full neighbors of the current node batch\n",
    "                block = block_sampler.sample(node_ids[batch_start:batch_start+batch_size])[0]\n",
    "                # Get the necessary input node IDs for this node batch on this layer\n",
    "                input_node_ids = block.srcdata[dgl.NID]\n",
    "                # Get the input features\n",
    "                h_input = h[input_node_ids]\n",
    "                # Compute the output of this node batch on this layer\n",
    "                new_h = sagenet.propagate(block, h_input, conv)\n",
    "                new_h_list.append(new_h)\n",
    "            # We finished computing all representations on this layer.  We need to compute the\n",
    "            # representations of next layer.\n",
    "            h = torch.cat(new_h_list)\n",
    "        \n",
    "    return h"
   ]
  },
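  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The control flow of `inference_with_sagenet` can be illustrated with plain Python on scalar node values. In this sketch, `layerwise_inference` and `mean_layer` are hypothetical stand-ins (a real layer would be a `SAGEConv`); the point is that the outer loop runs over layers, the inner loop over node batches, and every batch reads all neighbors from the previous layer's complete output. It assumes every node has at least one neighbor.\n",
    "\n",
    "```python\n",
    "def layerwise_inference(in_neighbors, features, layers, batch_size):\n",
    "    # features: node -> input value; each layer maps\n",
    "    # (own value, list of neighbor values) -> new value.\n",
    "    nodes = sorted(features)\n",
    "    h = features\n",
    "    for layer in layers:  # outer loop: one layer at a time\n",
    "        new_h = {}\n",
    "        for start in range(0, len(nodes), batch_size):  # inner loop: node batches\n",
    "            for v in nodes[start:start + batch_size]:\n",
    "                # Use all neighbors, read from the previous layer's full output.\n",
    "                new_h[v] = layer(h[v], [h[u] for u in in_neighbors[v]])\n",
    "        h = new_h  # every node now has this layer's output\n",
    "    return h\n",
    "\n",
    "def mean_layer(h_self, h_neighbors):\n",
    "    # Hypothetical layer: average of a node's own value and its neighbors' mean.\n",
    "    return 0.5 * (h_self + sum(h_neighbors) / len(h_neighbors))\n",
    "\n",
    "# Path graph 0 - 1 - 2 with scalar input features and two layers.\n",
    "in_neighbors = {0: [1], 1: [0, 2], 2: [1]}\n",
    "features = {0: 0.0, 1: 4.0, 2: 8.0}\n",
    "out = layerwise_inference(in_neighbors, features, [mean_layer, mean_layer], batch_size=2)\n",
    "print(out)  # {0: 3.0, 1: 4.0, 2: 5.0}\n",
    "```\n",
    "\n",
    "Changing `batch_size` does not change the result, only how many nodes are processed at once; in the notebook's real inference this is what bounds the memory used per step."
   ]
  },
```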
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Dataset\n",
    "\n",
    "Load a built-in dataset from DGL: the PubMed citation network, where nodes are papers and edges are citations.\n",
    "\n",
    "DGL provides many more built-in datasets, listed [here](https://doc.dgl.ai/api/python/data.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl.data\n",
    "\n",
    "dataset = dgl.data.citation_graph.load_pubmed()\n",
    "\n",
    "# Set features and labels for each node\n",
    "graph = dgl.graph(dataset.graph)\n",
    "graph.ndata['features'] = torch.FloatTensor(dataset.features)\n",
    "graph.ndata['labels'] = torch.LongTensor(dataset.labels)\n",
    "in_feats = dataset.features.shape[1]\n",
    "num_labels = dataset.num_labels\n",
    "\n",
    "# Find the node IDs in the training, validation, and test set.\n",
    "train_nid = dataset.train_mask.nonzero()[0]\n",
    "val_nid = dataset.val_mask.nonzero()[0]\n",
    "test_nid = dataset.test_mask.nonzero()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
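    "Alternatively, load the much larger [Amazon product co-purchasing network](https://ogb.stanford.edu/docs/nodeprop/#ogbn-products) (`ogbn-products`) from the [Open Graph Benchmark](https://ogb.stanford.edu/). Running the next cell overwrites the PubMed variables defined above, so the rest of this notebook trains on `ogbn-products`."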
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ogb.nodeproppred import DglNodePropPredDataset\n",
    "\n",
    "data = DglNodePropPredDataset(name='ogbn-products')\n",
    "splitted_idx = data.get_idx_split()\n",
    "graph, labels = data[0]\n",
    "labels = labels[:, 0]\n",
    "graph = dgl.as_heterograph(graph)\n",
    "\n",
    "graph.ndata['features'] = graph.ndata['feat']\n",
    "graph.ndata['labels'] = labels\n",
    "in_feats = graph.ndata['features'].shape[1]\n",
    "num_labels = len(torch.unique(labels))\n",
    "\n",
    "# Find the node IDs in the training, validation, and test set.\n",
    "train_nid, val_nid, test_nid = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test']\n",
    "print('|V|={}, |E|={}'.format(graph.number_of_nodes(), graph.number_of_edges()))\n",
    "print('train: {}, valid: {}, test: {}'.format(len(train_nid), len(val_nid), len(test_nid)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neighbor Sampler\n",
    "\n",
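    "We can reuse the neighbor sampler defined above. Here we sample up to 10 neighbors per node for the first layer and up to 25 for the second layer."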
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neighbor_sampler = NeighborSampler(graph, [10, 25])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define DataLoader\n",
    "\n",
    "We can use PyTorch's `DataLoader` to generate mini-batches.\n",
    "\n",
    "To compute the output of a mini-batch of nodes, we need the list of blocks described above. Therefore, we pass the sampler's `sample` method as the `collate_fn` argument, which defines how a list of individual examples (here, node IDs) is combined into a mini-batch.\n",
    "\n",
    "Another benefit of PyTorch's `DataLoader` is that it can generate mini-batches in parallel with multiple worker processes (the `num_workers` argument)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data\n",
    "\n",
    "BATCH_SIZE = 1000\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_nid, batch_size=BATCH_SIZE, collate_fn=neighbor_sampler.sample, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Model and Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HIDDEN_FEATURES = 50\n",
    "model = SAGENet(2, in_feats, num_labels, HIDDEN_FEATURES)\n",
    "\n",
    "opt = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(pred, labels):\n",
    "    return (pred.argmax(1) == labels).float().mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 200\n",
    "EVAL_BATCH_SIZE = 10000\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for blocks in train_dataloader:\n",
    "        input_nodes = blocks[0].srcdata[dgl.NID]\n",
    "        output_nodes = blocks[-1].dstdata[dgl.NID]\n",
    "        \n",
    "        input_features = graph.ndata['features'][input_nodes]\n",
    "        output_labels = graph.ndata['labels'][output_nodes]\n",
    "        \n",
    "        output_predictions = model(blocks, input_features)\n",
    "        loss = F.cross_entropy(output_predictions, output_labels)\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        model.eval()\n",
    "        all_predictions = inference_with_sagenet(model, graph, graph.ndata['features'], EVAL_BATCH_SIZE)\n",
    "\n",
    "        val_predictions = all_predictions[val_nid]\n",
    "        val_labels = graph.ndata['labels'][val_nid]\n",
    "        test_predictions = all_predictions[test_nid]\n",
    "        test_labels = graph.ndata['labels'][test_nid]\n",
    "\n",
    "        print('Validation acc:', compute_accuracy(val_predictions, val_labels),\n",
    "              'Test acc:', compute_accuracy(test_predictions, test_labels))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
